{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7680\n",
      "1640\n",
      "1660\n"
     ]
    }
   ],
   "source": [
    "#定义数据\n",
    "class SurnameDataset(Dataset):\n",
    "    def __init__(self, part):\n",
    "        data = pd.read_csv('./data/surnames/数字化数据.csv')\n",
    "        data = data[data.part == part]\n",
    "        self.data = data\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.data.iloc[i, 0], self.data.iloc[i, 1]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "\n",
    "train_dataset = SurnameDataset(part='train')\n",
    "val_dataset = SurnameDataset(part='val')\n",
    "test_dataset = SurnameDataset(part='test')\n",
    "\n",
    "print(len(train_dataset))\n",
    "print(len(val_dataset))\n",
    "print(len(test_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[10,  3, 16,  2, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],\n",
      "        [ 4, 20, 13, 27, 18, 17,  6,  9,  0,  0,  0,  0,  0,  0,  0],\n",
      "        [20,  6,  1, 15,  4, 13, 18,  9,  2, 27,  0,  0,  0,  0,  0],\n",
      "        [12, 13, 11, 16,  2,  6, 10,  0,  0,  0,  0,  0,  0,  0,  0],\n",
      "        [10,  3,  7, 15, 20, 17,  8,  8, 13,  0,  0,  0,  0,  0,  0]]) torch.Size([100, 15])\n",
      "tensor([ 6, 13, 13, 12, 12]) torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "def to_tensor(data):\n",
    "    N = len(data)\n",
    "    #N句话,每句话15个词\n",
    "    xs = np.zeros((N, 15))\n",
    "    ys = np.empty(N)\n",
    "    for i in range(N):\n",
    "        x, y = data[i]\n",
    "        ys[i] = y\n",
    "\n",
    "        x = x.split(',') + [0] * 15\n",
    "        x = x[:15]\n",
    "        xs[i] = x\n",
    "\n",
    "    return torch.LongTensor(xs), torch.LongTensor(ys)\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "train_dataloader = DataLoader(dataset=train_dataset,\n",
    "                              batch_size=100,\n",
    "                              shuffle=True,\n",
    "                              drop_last=True,\n",
    "                              collate_fn=to_tensor)\n",
    "\n",
    "val_dataloader = DataLoader(dataset=val_dataset,\n",
    "                            batch_size=100,\n",
    "                            shuffle=True,\n",
    "                            drop_last=True,\n",
    "                            collate_fn=to_tensor)\n",
    "\n",
    "test_dataloader = DataLoader(dataset=test_dataset,\n",
    "                             batch_size=100,\n",
    "                             shuffle=True,\n",
    "                             drop_last=True,\n",
    "                             collate_fn=to_tensor)\n",
    "\n",
    "#遍历数据\n",
    "for i, data in enumerate(train_dataloader):\n",
    "    x, y = data\n",
    "    print(x[:5], x.shape)\n",
    "    print(y[:5], y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.3216, -0.2514,  0.0087, -0.2214, -0.4007, -0.0327, -0.3790, -0.0906,\n",
       "         -0.2225,  0.2062,  0.2101, -0.0179,  0.3571, -0.1342,  0.1497,  0.1468,\n",
       "         -0.0438,  0.2620],\n",
       "        [ 0.1919,  0.1054, -0.1934, -0.0700, -0.1049,  0.1769, -0.2508, -0.0366,\n",
       "          0.0531,  0.0722, -0.2265, -0.1512,  0.0072, -0.0551,  0.0144,  0.1699,\n",
       "         -0.0732,  0.1712]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义网络模型\n",
    "class SurnameClassifier(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SurnameClassifier, self).__init__()\n",
    "\n",
    "        self.embedding = nn.Embedding(num_embeddings=30,\n",
    "                                      embedding_dim=50,\n",
    "                                      padding_idx=0)\n",
    "\n",
    "        self.rnn_cell = nn.RNNCell(50, 100)\n",
    "\n",
    "        self.fc1 = nn.Linear(in_features=100, out_features=100)\n",
    "        self.fc2 = nn.Linear(in_features=100, out_features=18)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        b = x.shape[0]\n",
    "\n",
    "        #[b,15] -> [b,15,20]\n",
    "        embed = self.embedding(x)\n",
    "\n",
    "        #[b,15,20] -> [b,30]\n",
    "        out = torch.zeros((b, 100))\n",
    "        for i in range(15):\n",
    "            out = self.rnn_cell(embed[:, i, :], out)\n",
    "\n",
    "        #[b,30] -> [b,18]\n",
    "        out = F.relu(self.fc1(F.dropout(out, 0.5)))\n",
    "        out = self.fc2(F.dropout(out, 0.5))\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "model = SurnameClassifier()\n",
    "model(torch.ones(2, 15).long())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.045"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test(dataloader):\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for i, data in enumerate(dataloader):\n",
    "        x, y = data\n",
    "\n",
    "        y_pred = model(x)\n",
    "        y_pred = y_pred.argmax(dim=1)\n",
    "\n",
    "        correct += (y_pred == y).sum().item()\n",
    "        total += len(y)\n",
    "\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "test(val_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.1757946014404297 0.253125\n",
      "1 2.196861982345581 0.28125\n",
      "2 1.7546985149383545 0.453125\n",
      "3 1.6095668077468872 0.495\n",
      "4 1.5190417766571045 0.52375\n",
      "5 1.5627827644348145 0.54125\n",
      "6 1.4074265956878662 0.56\n",
      "7 1.3997132778167725 0.586875\n",
      "8 1.5140800476074219 0.61625\n",
      "9 1.5021535158157349 0.620625\n",
      "10 1.3012571334838867 0.610625\n",
      "11 1.3222830295562744 0.631875\n",
      "12 1.3981198072433472 0.6525\n",
      "13 1.2540010213851929 0.64\n",
      "14 1.4072508811950684 0.63625\n",
      "15 1.2692131996154785 0.64625\n",
      "16 1.2988580465316772 0.656875\n",
      "17 1.118596076965332 0.635625\n",
      "18 1.3047208786010742 0.6625\n",
      "19 1.0393601655960083 0.655\n"
     ]
    }
   ],
   "source": [
    "loss_func = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "model.train()\n",
    "for epoch in range(20):\n",
    "    for i, data in enumerate(train_dataloader):\n",
    "        x, y = data\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "\n",
    "        loss = loss_func(y_pred, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    if epoch % 1 == 0:\n",
    "        accurecy = test(val_dataloader)\n",
    "        print(epoch, loss.item(), accurecy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.655625"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(test_dataloader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "138px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
